# Copyright (c) OpenMMLab. All rights reserved.
import torch


def filter_outside_objs(gt_bboxes_list, gt_labels_list, gt_bboxes_3d_list,
                        gt_labels_3d_list, centers2d_list, img_metas):
    """Drop objects whose projected 3D center lies outside the image.

    Every list is updated in place and nothing is returned. For image ``i``
    an object survives only if its projected center ``(u, v)`` satisfies
    ``0 < u < img_shape[1]`` and ``0 < v < img_shape[0]`` (strict bounds),
    where ``img_shape = img_metas[i]['img_shape']``.

    Args:
        gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image,
            each has shape (num_gt, 4).
        gt_labels_list (list[Tensor]): Ground truth labels of each box,
            each has shape (num_gt,).
        gt_bboxes_3d_list (list): 3D ground truth boxes of each image. Each
            element must expose a ``.tensor`` attribute of shape
            (num_gt, bbox_code_size); that attribute is replaced in place.
        gt_labels_3d_list (list[Tensor]): 3D Ground truth labels of each
            box, each has shape (num_gt,).
        centers2d_list (list[Tensor]): Projected 3D centers onto 2D image,
            each has shape (num_gt, 2).
        img_metas (list[dict]): Meta information of each image, e.g.,
            image size, scaling factor, etc. Must contain ``'img_shape'``.
    """
    for idx, centers2d in enumerate(centers2d_list):
        img_shape = img_metas[idx]['img_shape']
        height, width = img_shape[0], img_shape[1]
        u, v = centers2d[:, 0], centers2d[:, 1]
        keep = (u > 0) & (u < width) & (v > 0) & (v < height)

        # Boolean-mask indexing always yields a fresh tensor, so the caller's
        # original tensors are left untouched; only the list slots change.
        centers2d_list[idx] = centers2d[keep]
        gt_labels_list[idx] = gt_labels_list[idx][keep]
        gt_bboxes_list[idx] = gt_bboxes_list[idx][keep]
        boxes_3d = gt_bboxes_3d_list[idx]
        boxes_3d.tensor = boxes_3d.tensor[keep]
        gt_labels_3d_list[idx] = gt_labels_3d_list[idx][keep]


def get_centers2d_target(centers2d, centers, img_shape):
    """Function to get target centers2d.

    For each object, a line is drawn through its projected 3D center
    (``centers2d``) and the center of its 2D gt bbox (``centers``). The line
    is intersected with the four image borders ``x = 0``, ``x = w - 1``,
    ``y = 0`` and ``y = h - 1``. Among the intersections that lie on the
    image boundary, the one closest to ``centers2d`` becomes the target.

    Degenerate geometry is handled explicitly:

    - vertical lines (``centers`` and ``centers2d`` share the same x) use
      the top/bottom border points at that x instead of an infinite slope;
    - lines through an image corner may have more than two valid border
      intersections; all of them are considered;
    - if no border intersection is valid for an object, ``centers2d`` is
      clamped into ``[0, w - 1] x [0, h - 1]`` as a fallback.

    Args:
        centers2d (Tensor): Projected 3D centers onto 2D images, shape (N, 2).
        centers (Tensor): Centers of 2d gt bboxes, shape (N, 2).
        img_shape (tuple): Resized image shape; the first two entries are
            read as (h, w).

    Returns:
        torch.Tensor: Projected 3D centers (centers2D) target, shape (N, 2).
    """
    N = centers2d.shape[0]
    h, w = img_shape[:2]
    dx = centers[:, 0] - centers2d[:, 0]
    dy = centers[:, 1] - centers2d[:, 1]
    # A line with dx == 0 is vertical and has no finite slope. Divide by 1
    # there to keep the arithmetic finite; its candidates are patched below.
    vertical = dx == 0
    a = dy / torch.where(vertical, torch.ones_like(dx), dx)
    b = centers[:, 1] - a * centers[:, 0]
    nan = a.new_full((N, ), float('nan'))
    # y where the line meets x = 0 / x = w - 1 (never, for a vertical line).
    left_y = torch.where(vertical, nan, b)
    right_y = torch.where(vertical, nan, (w - 1) * a + b)
    # x where the line meets y = 0 / y = h - 1 (its own x, for a vertical
    # line). Horizontal lines (a == 0) give +-inf / nan here and are rejected
    # by the range check below.
    top_x = torch.where(vertical, centers[:, 0], -b / a)
    bottom_x = torch.where(vertical, centers[:, 0], (h - 1 - b) / a)

    left_coors = torch.stack((left_y.new_zeros(N, ), left_y), dim=1)
    right_coors = torch.stack((right_y.new_full((N, ), w - 1), right_y), dim=1)
    top_coors = torch.stack((top_x, top_x.new_zeros(N, )), dim=1)
    bottom_coors = torch.stack((bottom_x, bottom_x.new_full((N, ), h - 1)),
                               dim=1)

    # (N, 4, 2): candidates ordered left, right, top, bottom.
    intersects = torch.stack(
        [left_coors, right_coors, top_coors, bottom_coors], dim=1)
    intersects_x = intersects[:, :, 0]
    intersects_y = intersects[:, :, 1]
    valid = (intersects_x >= 0) & (intersects_x <= w - 1) & \
        (intersects_y >= 0) & (intersects_y <= h - 1)

    # Score all four candidates and push invalid ones to +inf rather than
    # reshaping the valid subset to (N, 2, 2), which silently assumes exactly
    # two valid candidates per row. argmin returns the first minimum, so ties
    # resolve in left/right/top/bottom order.
    dist = torch.norm(intersects - centers2d.unsqueeze(1), dim=2)
    dist = dist.masked_fill(~valid, float('inf'))
    min_idx = torch.argmin(dist, dim=1)

    min_idx = min_idx.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 2)
    centers2d_target = intersects.gather(dim=1, index=min_idx).squeeze(1)

    # Fallback for rows whose line never touches the image boundary.
    has_valid = valid.any(dim=1, keepdim=True)
    clamped = torch.stack(
        (centers2d[:, 0].clamp(0, w - 1), centers2d[:, 1].clamp(0, h - 1)),
        dim=1)
    return torch.where(has_valid, centers2d_target, clamped)


def handle_proj_objs(centers2d_list, gt_bboxes_list, img_metas):
    """Build centers2d targets, re-targeting truncated objects.

    An object counts as truncated when its projected center ``(u, v)`` does
    NOT satisfy ``0 < u < img_shape[1]`` and ``0 < v < img_shape[0]``. For
    truncated objects the target is computed by ``get_centers2d_target`` from
    the projected center and the center of the matching 2D gt bbox. Every
    other object keeps its projected center as the target.

    Args:
        centers2d_list (list[Tensor]): Projected 3D centers onto 2D image,
            each of shape (num_gt, 2).
        gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image,
            each of shape (num_gt, 4); the bbox center is computed as the
            mean of columns [0:2] and [2:4].
        img_metas (list[dict]): Meta information of each image, e.g.,
            image size, scaling factor, etc. Must contain ``'img_shape'``.

    Returns:
        tuple[list[Tensor]]: Three per-image lists, in order:
            the target centers2d after handling truncated objects; the
            offsets ``centers2d - round(target)`` (rounded target cast to
            int); and the boolean truncation mask for each object.
    """
    centers2d_target_list, offsets2d_list, trunc_mask_list = [], [], []
    # For now, only the pad mode where the image is padded on the right and
    # bottom sides is supported.
    for idx, centers2d in enumerate(centers2d_list):
        gt_bbox = gt_bboxes_list[idx]
        img_shape = img_metas[idx]['img_shape']
        u, v = centers2d[:, 0], centers2d[:, 1]
        inside = (u > 0) & (u < img_shape[1]) & (v > 0) & (v < img_shape[0])
        truncated = ~inside

        # Copy so that writing targets never touches the caller's tensor.
        target = centers2d.clone()
        if truncated.any():
            box_centers = (gt_bbox[:, :2] + gt_bbox[:, 2:]) / 2
            target[truncated] = get_centers2d_target(
                centers2d[truncated], box_centers[truncated], img_shape)

        centers2d_target_list.append(target)
        trunc_mask_list.append(truncated)
        offsets2d_list.append(centers2d - target.round().int())

    return (centers2d_target_list, offsets2d_list, trunc_mask_list)
